""" Functions for simulating the evolution of the state.

"""
import numpy as np
from numba import njit


@njit
def next_state(s, const, r, sigma):
    """ Iterate through to the next state.

    Computes ``s_next = const + shock + r @ s``, where ``shock`` is a vector
    of zeros except for its first entry, which is a single draw from a normal
    distribution with mean 0 and standard deviation ``sigma``. Only the first
    state component receives noise.

    Args:
        s: 1D np.array of length n (n = 4 in the original model). The current
            state.
        const: 1D np.array of length n. The constant term added every step.
        r: n x n np.array. The matrix applied to the current state.
        sigma: float. Standard deviation of the shock to the first component.

    Returns:
        s_next: 1D np.array of length n. The next state.

    Raises:
        ValueError: If ``s`` is empty.
    """
    n = s.shape[0]
    if n == 0:
        # Without this guard shock[0] below would be an unchecked
        # out-of-bounds write under Numba (no bounds checking by default).
        raise ValueError("next_state: state vector must not be empty")
    # Inside a jitted function this uses Numba's own RNG; per the Numba docs,
    # np.random.seed() only affects it when called from jitted code.
    e = np.random.normal(loc=0.0, scale=sigma)
    # Size the shock from the state rather than hard-coding 4 entries, so any
    # state dimension works; float64 matches the original literal array.
    shock = np.zeros(n)
    shock[0] = e
    s_next = const + shock + r @ s
    return s_next

@njit
def rand_choice_nb(arr, prob):
    """ Sample from a vector of probabilities

    Numba-compatible replacement for a single ``np.random.choice(arr, p=prob)``
    draw, using inverse-CDF sampling on ``np.cumsum(prob)``.

    Args:
        arr: A 1D numpy array of values to sample from.
        prob: A 1D numpy array of non-negative weights, one per entry of
            ``arr`` (the two arrays must have the same length). The weights
            need not sum exactly to 1: the uniform draw is scaled by their
            total, so rounding error or unnormalized weights cannot push the
            index past the end of ``arr``.

    Returns:
        A random sample from the given array, chosen with probability
        proportional to the corresponding entry of ``prob``.

    Raises:
        ValueError: If ``prob`` is empty or its total is not positive.
    """
    cdf = np.cumsum(prob)
    if cdf.shape[0] == 0:
        raise ValueError("rand_choice_nb: prob must not be empty")
    total = cdf[-1]
    # `not total > 0` also rejects NaN totals.
    if not total > 0.0:
        raise ValueError("rand_choice_nb: prob must have a positive sum")
    # Scale the uniform draw by the actual total mass instead of assuming 1.0;
    # otherwise a cdf ending at 0.9999... can yield an index == len(arr).
    u = np.random.random() * total
    # side="right" skips zero-probability entries when u lands exactly on a
    # cumulative boundary (e.g. u == 0.0 with prob[0] == 0).
    idx = np.searchsorted(cdf, u, side="right")
    if idx >= cdf.shape[0]:
        # Only reachable if u rounded up to exactly `total`. Pick the first
        # index whose cdf reaches the total (the last entry with positive
        # mass). Numba does not bounds-check array access by default, so an
        # index of len(arr) would read arbitrary memory.
        idx = np.searchsorted(cdf, total, side="left")
    return arr[idx]

